# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flags which will be nearly universal across models."""

from absl import flags
import tensorflow as tf
from ._conventions import help_wrap


def define_base(data_dir=True,
                model_dir=True,
                clean=False,
                train_epochs=False,
                epochs_between_evals=False,
                stop_threshold=False,
                batch_size=True,
                num_gpu=False,
                hooks=False,
                export_dir=False,
                distribution_strategy=False,
                run_eagerly=False):
    """Define the commonly shared absl flags selected by the arguments.

    Each argument is a boolean switch; when truthy, the corresponding flag is
    defined through `absl.flags`.

    Args:
      data_dir: Define `--data_dir` (input data directory).
      model_dir: Define `--model_dir` (model checkpoint directory).
      clean: Define `--clean` (remove model_dir if it exists).
      train_epochs: Define `--train_epochs` (number of training epochs).
      epochs_between_evals: Define `--epochs_between_evals` (training epochs
        run between evaluations).
      stop_threshold: Define `--stop_threshold` (eval metric value at which
        training should stop).
      batch_size: Define `--batch_size`.
      num_gpu: Define `--num_gpus` (note the flag name is plural).
      hooks: Define `--hooks` (names of training hooks).
      export_dir: Define `--export_dir` (SavedModel export directory).
      distribution_strategy: Define `--distribution_strategy`.
      run_eagerly: Define `--run_eagerly` (run op by op).

    Returns:
      A list of flag names to be marked as key flags. Only `data_dir`,
      `model_dir`, `clean`, `train_epochs`, `epochs_between_evals`,
      `batch_size`, `hooks` and `export_dir` are ever added to it; the
      `stop_threshold`, `num_gpus`, `run_eagerly` and `distribution_strategy`
      flags are defined but not included in the returned list.
    """
    exposed = []

    if data_dir:
        description = help_wrap("The location of the input data.")
        flags.DEFINE_string(
            name="data_dir", short_name="dd", default="/tmp",
            help=description)
        exposed.append("data_dir")

    if model_dir:
        description = help_wrap("The location of the model checkpoint files.")
        flags.DEFINE_string(
            name="model_dir", short_name="md", default="/tmp",
            help=description)
        exposed.append("model_dir")

    if clean:
        description = help_wrap(
            "If set, model_dir will be removed if it exists.")
        flags.DEFINE_boolean(name="clean", default=False, help=description)
        exposed.append("clean")

    if train_epochs:
        description = help_wrap("The number of epochs used to train.")
        flags.DEFINE_integer(
            name="train_epochs", short_name="te", default=1,
            help=description)
        exposed.append("train_epochs")

    if epochs_between_evals:
        description = help_wrap(
            "The number of training epochs to run between "
            "evaluations.")
        flags.DEFINE_integer(
            name="epochs_between_evals", short_name="ebe", default=1,
            help=description)
        exposed.append("epochs_between_evals")

    if stop_threshold:
        # Defined only; this flag is not added to the returned list.
        description = help_wrap(
            "If passed, training will stop at the earlier of "
            "train_epochs and when the evaluation metric is  "
            "greater than or equal to stop_threshold.")
        flags.DEFINE_float(
            name="stop_threshold", short_name="st", default=None,
            help=description)

    if batch_size:
        description = help_wrap(
            "Batch size for training and evaluation. When using "
            "multiple gpus, this is the global batch size for "
            "all devices. For example, if the batch size is 32 "
            "and there are 4 GPUs, each GPU will get 8 examples on "
            "each step.")
        flags.DEFINE_integer(
            name="batch_size", short_name="bs", default=32,
            help=description)
        exposed.append("batch_size")

    if num_gpu:
        # Defined only; this flag is not added to the returned list.
        description = help_wrap(
            "How many GPUs to use at each worker with the "
            "DistributionStrategies API. The default is 1.")
        flags.DEFINE_integer(
            name="num_gpus", short_name="ng", default=1,
            help=description)

    if run_eagerly:
        # Defined only; the help text is passed as-is, without help_wrap.
        description = (
            "Run the model op by op without building a model function.")
        flags.DEFINE_boolean(
            name="run_eagerly", default=False, help=description)

    if hooks:
        description = help_wrap(
            u"A list of (case insensitive) strings to specify the names of "
            u"training hooks. Example: `--hooks ProfilerHook,"
            u"ExamplesPerSecondHook`\n See hooks_helper "
            u"for details.")
        flags.DEFINE_list(
            name="hooks", short_name="hk", default="LoggingTensorHook",
            help=description)
        exposed.append("hooks")

    if export_dir:
        description = help_wrap(
            "If set, a SavedModel serialization of the model will "
            "be exported to this directory at the end of training. "
            "See the README for more details and relevant links.")
        flags.DEFINE_string(
            name="export_dir", short_name="ed", default=None,
            help=description)
        exposed.append("export_dir")

    if distribution_strategy:
        # Defined only; this flag is not added to the returned list.
        description = help_wrap(
            "The Distribution Strategy to use for training. "
            "Accepted values are 'off', 'one_device', "
            "'mirrored', 'parameter_server', 'collective', "
            "case insensitive. 'off' means not to use "
            "Distribution Strategy; 'default' means to choose "
            "from `MirroredStrategy` or `OneDeviceStrategy` "
            "according to the number of GPUs.")
        flags.DEFINE_string(
            name="distribution_strategy", short_name="ds",
            default="mirrored", help=description)

    return exposed


def get_num_gpus(flags_obj):
    """Return the number of GPUs to use, treating num_gpus=-1 as 'use all'.

    Args:
      flags_obj: Object exposing a `num_gpus` attribute.

    Returns:
      `flags_obj.num_gpus` when it is not -1; otherwise the number of local
      devices whose `device_type` is "GPU".
    """
    requested = flags_obj.num_gpus
    if requested == -1:
        # Imported lazily so TensorFlow's device listing is only pulled in
        # when the local devices actually have to be enumerated.
        from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
        gpu_count = 0
        for device in device_lib.list_local_devices():
            if device.device_type == "GPU":
                gpu_count += 1
        return gpu_count

    return requested
